{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Shakespeare in 5 minutes with Cloud TPUs and Keras",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "metadata": {
        "id": "xzpUtDMqmA-x",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Shakespeare in 5 minutes with Cloud TPUs and Keras\n",
        "\n",
        "This notebook demonstrates using Cloud TPUs to build a _language model_: a model that predicts the next character of text given the text so far.  Once our model has been trained we can sample from it to generate new text that looks like the text it was trained on.  In this case we're going to train our network using the combined works of Shakespeare, creating a play-generating robot.\n",
        "\n",
        "Our network outputs something Shakespeare-esque:\n",
        "\n",
        "___\n",
        "<blockquote>\n",
        "Loves that led me no dumbs lack her Berjoy's face with her to-day.  \n",
        "The spirits roar'd; which shames which within his powers  \n",
        "\tWhich tied up remedies lending with occasion,  \n",
        "A loud and Lancaster, stabb'd in me  \n",
        "\tUpon my sword for ever: 'Agripo'er, his days let me free.  \n",
        "\tStop it of that word, be so: at Lear,  \n",
        "\tWhen I did profess the hour-stranger for my life,  \n",
        "\tWhen I did sink to be cried how for aught;  \n",
        "\tSome beds which seeks chaste senses prove burning;  \n",
        "But he perforces seen in her eyes so fast;  \n",
        "And _  \n",
        "</blockquote>\n",
        "___\n",
        "\n",
        "Let's get started on generating our own Shakespeare!  We'll start off with our data generator.  The training data to our model will be snippets from our text file: the _target_ snippet is offset by one character."
      ]
    },
    {
      "metadata": {
        "id": "j8sIXh1DEDDd",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 71
        },
        "outputId": "3635cec2-ddea-45f8-8d4a-92b747a67c73"
      },
      "cell_type": "code",
      "source": [
        "!wget --show-progress --continue -O /content/shakespeare.txt http://www.gutenberg.org/files/100/100-0.txt"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n",
            "Redirecting output to ‘wget-log’.\n",
            "/content/shakespear 100%[===================>]   5.58M  8.24MB/s    in 0.7s    \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "E3V4V-Jxmuv3",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 251
        },
        "outputId": "f13aed46-6ff4-48e3-ef36-33226197e457"
      },
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import six\n",
        "import tensorflow as tf\n",
        "import time\n",
        "import os\n",
        "\n",
        "# This address identifies the TPU we'll use when configuring TensorFlow.\n",
        "TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']\n",
        "\n",
        "SHAKESPEARE_TXT = '/content/shakespeare.txt'\n",
        "\n",
        "tf.logging.set_verbosity(tf.logging.INFO)\n",
        "\n",
        "def transform(txt, pad_to=None):\n",
        "  # drop any non-ascii characters\n",
        "  output = np.asarray([ord(c) for c in txt if ord(c) < 255], dtype=np.int32)\n",
        "  if pad_to is not None:\n",
        "    output = output[:pad_to]\n",
        "    output = np.concatenate([\n",
        "        np.zeros([pad_to - len(txt)], dtype=np.int32),\n",
        "        output,\n",
        "    ])\n",
        "  return output\n",
        "\n",
        "def training_generator(seq_len=100, batch_size=1024):\n",
        "  \"\"\"A generator yields (source, target) arrays for training.\"\"\"\n",
        "  with tf.gfile.GFile(SHAKESPEARE_TXT, 'r') as f:\n",
        "    txt = f.read()\n",
        "\n",
        "  tf.logging.info('Input text [%d] %s', len(txt), txt[:50])\n",
        "  source = transform(txt)\n",
        "  while True:\n",
        "    offsets = np.random.randint(0, len(source) - seq_len, batch_size)\n",
        "\n",
        "    # Our model uses sparse crossentropy loss, but Keras requires labels\n",
        "    # to have the same rank as the input logits.  We add an empty final\n",
        "    # dimension to account for this.\n",
        "    yield (\n",
        "        np.stack([source[idx:idx + seq_len] for idx in offsets]),\n",
        "        np.expand_dims(\n",
        "            np.stack([source[idx + 1:idx + seq_len + 1] for idx in offsets]),\n",
        "            -1),\n",
        "    )\n",
        "\n",
        "six.next(training_generator(seq_len=10, batch_size=1))"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Input text [5834393] ﻿\r\n",
            "Project Gutenberg’s The Complete Works of Willi\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(array([[103, 101, 114,  13,  10,  32,  32,  32,  32,  79]], dtype=int32),\n",
              " array([[[101],\n",
              "         [114],\n",
              "         [ 13],\n",
              "         [ 10],\n",
              "         [ 32],\n",
              "         [ 32],\n",
              "         [ 32],\n",
              "         [ 32],\n",
              "         [ 79],\n",
              "         [117]]], dtype=int32))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 2
        }
      ]
    },
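    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The generator above pairs each source snippet with a target shifted one character to the right.  As a minimal sketch of that offset, the next cell slices a short stand-in string at a fixed index instead of sampling random offsets from the corpus.  The names `toy_txt`, `toy_ids`, `toy_seq_len`, `toy_idx`, `toy_source`, and `toy_target` are hypothetical and used only for illustration."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical toy text standing in for the Shakespeare corpus.\n",
        "toy_txt = 'To be, or not to be'\n",
        "toy_ids = np.asarray([ord(c) for c in toy_txt], dtype=np.int32)\n",
        "\n",
        "toy_seq_len = 5\n",
        "toy_idx = 3  # a fixed offset instead of a random one, for reproducibility\n",
        "\n",
        "# Same slicing as in training_generator, for a single offset.\n",
        "toy_source = toy_ids[toy_idx:toy_idx + toy_seq_len]\n",
        "toy_target = toy_ids[toy_idx + 1:toy_idx + toy_seq_len + 1]\n",
        "\n",
        "print(repr(''.join(chr(c) for c in toy_source)))\n",
        "print(repr(''.join(chr(c) for c in toy_target)))\n",
        "\n",
        "# Each target character is the one that follows the matching source character.\n",
        "assert np.array_equal(toy_source[1:], toy_target[:-1])"
      ],
      "execution_count": 0,
      "outputs": []
    },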
    {
      "metadata": {
        "id": "Bbb05dNynDrQ",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Building our model\n",
        "\n",
        "Now that we have some data, we can define our model.  We use a simple 2 layer, forward LSTM for this notebook.\n",
        "\n",
        "We make 2 changes from a standard LSTM definition in Keras.  First, we define the `shape` for the input of our model.  This allows TF to infer the shape of the model and satisfy the XLA compiler's static shape requirement.\n",
        "\n",
        "Second, we use a `tf.train.Optimizer` instead of a standard Keras optimizer (Keras optimizer support is still experimental.)"
      ]
    },
    {
      "metadata": {
        "id": "yLEM-fLJlEEt",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "EMBEDDING_DIM = 512\n",
        "\n",
        "def lstm_model(seq_len=100, batch_size=None, stateful=True):\n",
        "  \"\"\"Language model: predict the next word given the current word.\"\"\"\n",
        "  source = tf.keras.Input(\n",
        "      name='seed', shape=(seq_len,), batch_size=batch_size, dtype=tf.int32)\n",
        "\n",
        "  embedding = tf.keras.layers.Embedding(input_dim=256, output_dim=EMBEDDING_DIM)(source)\n",
        "  lstm_1 = tf.keras.layers.LSTM(EMBEDDING_DIM, stateful=stateful, return_sequences=True)(embedding)\n",
        "  lstm_2 = tf.keras.layers.LSTM(EMBEDDING_DIM, stateful=stateful, return_sequences=True)(lstm_1)\n",
        "  predicted_char = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(256, activation='softmax'))(lstm_2)\n",
        "  model = tf.keras.Model(inputs=[source], outputs=[predicted_char])\n",
        "  model.compile(\n",
        "      optimizer=tf.train.RMSPropOptimizer(learning_rate=0.01),\n",
        "      loss='sparse_categorical_crossentropy',\n",
        "      metrics=['sparse_categorical_accuracy'])\n",
        "  return model"
      ],
      "execution_count": 0,
      "outputs": []
    },
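    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The model is compiled with `sparse_categorical_crossentropy`, which takes integer character IDs as labels rather than one-hot vectors.  As a rough NumPy sketch of what that loss computes (not the Keras implementation), the hypothetical helper `toy_sparse_crossentropy` in the next cell averages the negative log-probability assigned to each true character.  The toy inputs are assumptions chosen so the result is easy to check: a uniform prediction over 256 characters gives ln(256)."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "def toy_sparse_crossentropy(probs, labels):\n",
        "  \"\"\"Mean negative log-probability assigned to the true class IDs.\"\"\"\n",
        "  probs = np.asarray(probs)\n",
        "  labels = np.asarray(labels).reshape(-1)\n",
        "  flat = probs.reshape(-1, probs.shape[-1])\n",
        "  picked = flat[np.arange(len(labels)), labels]\n",
        "  return float(-np.mean(np.log(picked)))\n",
        "\n",
        "# Assumed toy data: a uniform prediction over 256 character IDs for a batch\n",
        "# of 2 sequences of length 3, with labels carrying the extra final dimension.\n",
        "toy_probs = np.full((2, 3, 256), 1.0 / 256)\n",
        "toy_labels = np.array([[[84], [111], [32]], [[98], [101], [44]]])\n",
        "\n",
        "# A uniform guess costs ln(256), roughly 5.545.\n",
        "print(toy_sparse_crossentropy(toy_probs, toy_labels))"
      ],
      "execution_count": 0,
      "outputs": []
    },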
    {
      "metadata": {
        "id": "VzBYDJI0_Tfm",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Training our model\n",
        "\n",
        "The `tf.contrib.tpu.keras_to_tpu_model` function converts our Keras model to an equivalent TPU version.  We can then use our standard `fit`, `predict`, `evaluate` Keras methods to train."
      ]
    },
    {
      "metadata": {
        "id": "ExQ922tfzSGA",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 901
        },
        "outputId": "3e4b89bf-9e9e-4a9a-9ce0-775a39d9f2ce"
      },
      "cell_type": "code",
      "source": [
        "tf.keras.backend.clear_session()\n",
        "\n",
        "training_model = lstm_model(seq_len=100, batch_size=128, stateful=False)\n",
        "\n",
        "tpu_model = tf.contrib.tpu.keras_to_tpu_model(\n",
        "    training_model,\n",
        "    strategy=tf.contrib.tpu.TPUDistributionStrategy(\n",
        "        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))\n",
        "\n",
        "tpu_model.fit_generator(\n",
        "    training_generator(seq_len=100, batch_size=1024),\n",
        "    steps_per_epoch=100,\n",
        "    epochs=10,\n",
        ")\n",
        "tpu_model.save_weights('/tmp/bard.h5', overwrite=True)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Querying Tensorflow master (b'grpc://10.118.7.82:8470') for TPU system metadata.\n",
            "INFO:tensorflow:Found TPU system:\n",
            "INFO:tensorflow:*** Num TPU Cores: 8\n",
            "INFO:tensorflow:*** Num TPU Workers: 1\n",
            "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 6638032082838577689)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 4873016205556938351)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 5471416470704720555)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 6926907012367290755)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 4354044869524745214)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 16938815156612161417)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 6950840633641785585)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 16768792233987872678)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 12675314328015960221)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 16547454787033470948)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 10102547023578588130)\n",
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 17447963411858084871)\n",
            "WARNING:tensorflow:tpu_model (from tensorflow.contrib.tpu.python.tpu.keras_support) is experimental and may change or be removed at any time, and without warning.\n",
            "INFO:tensorflow:Connecting to: b'grpc://10.118.7.82:8470'\n",
            "Epoch 1/10\n",
            "INFO:tensorflow:Input text [5834393] ﻿\n",
            "Project Gutenberg’s The Complete Works of Willi\n",
            "INFO:tensorflow:New input shapes; (re-)compiling: mode=train, [TensorSpec(shape=(128, 100), dtype=tf.int32, name='seed0'), TensorSpec(shape=(128, 100, 1), dtype=tf.float32, name='time_distributed_target_10')]\n",
            "INFO:tensorflow:Overriding default placeholder.\n",
            "INFO:tensorflow:Remapping placeholder for seed\n",
            "INFO:tensorflow:Started compiling\n",
            "INFO:tensorflow:Finished compiling. Time elapsed: 3.9678761959075928 secs\n",
            "INFO:tensorflow:Setting weights on TPU model.\n",
            "100/100 [==============================] - 27s 268ms/step - loss: 4.5166 - sparse_categorical_accuracy: 0.1821\n",
            "Epoch 2/10\n",
            "100/100 [==============================] - 17s 167ms/step - loss: 3.3698 - sparse_categorical_accuracy: 0.1959\n",
            "Epoch 3/10\n",
            "100/100 [==============================] - 17s 167ms/step - loss: 2.1526 - sparse_categorical_accuracy: 0.3880\n",
            "Epoch 4/10\n",
            "100/100 [==============================] - 17s 168ms/step - loss: 1.5054 - sparse_categorical_accuracy: 0.5477\n",
            "Epoch 5/10\n",
            "100/100 [==============================] - 17s 168ms/step - loss: 1.3133 - sparse_categorical_accuracy: 0.5975\n",
            "Epoch 6/10\n",
            "100/100 [==============================] - 17s 170ms/step - loss: 1.2355 - sparse_categorical_accuracy: 0.6173\n",
            "Epoch 7/10\n",
            "100/100 [==============================] - 17s 167ms/step - loss: 1.1957 - sparse_categorical_accuracy: 0.6273\n",
            "Epoch 8/10\n",
            "100/100 [==============================] - 17s 169ms/step - loss: 1.1664 - sparse_categorical_accuracy: 0.6355\n",
            "Epoch 9/10\n",
            "100/100 [==============================] - 17s 167ms/step - loss: 1.1466 - sparse_categorical_accuracy: 0.6407\n",
            "Epoch 10/10\n",
            "100/100 [==============================] - 17s 167ms/step - loss: 1.1344 - sparse_categorical_accuracy: 0.6437\n",
            "INFO:tensorflow:Copying TPU weights to the CPU\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "TCBtcpZkykSf",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Running predictions with our model\n",
        "\n",
        "We've trained our model, now we can run predictions through it to generate \"Shakespeare\"!  We provide a seed sentence to get our model started, and then sample 250 characters from it.\n",
        "\n",
        "We'll make 5 predictions from the initial seed."
      ]
    },
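    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Generation works by sampling each next character from the model's softmax output rather than always taking the most likely character, which is why the same seed can yield several different continuations.  The next cell is a minimal sketch of that sampling step, assuming a hypothetical 3-way distribution `toy_dist` in place of the model's 256-way output.  The seeded `rng` is also an assumption, added only to make the sketch reproducible."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical distribution standing in for one row of the model's output.\n",
        "toy_dist = np.array([0.1, 0.6, 0.3])\n",
        "\n",
        "# Seeded generator so the sketch is reproducible.\n",
        "rng = np.random.RandomState(0)\n",
        "samples = rng.choice(len(toy_dist), size=10000, p=toy_dist)\n",
        "\n",
        "# Empirical frequencies approach toy_dist; argmax would always pick index 1.\n",
        "freqs = np.bincount(samples, minlength=len(toy_dist)) / float(len(samples))\n",
        "print(freqs)\n",
        "print(np.argmax(toy_dist))"
      ],
      "execution_count": 0,
      "outputs": []
    },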
    {
      "metadata": {
        "id": "tU7M-EGGxR3E",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1241
        },
        "outputId": "18dce114-3500-4258-e628-e6326400fe1f"
      },
      "cell_type": "code",
      "source": [
        "BATCH_SIZE = 5\n",
        "PREDICT_LEN = 250\n",
        "\n",
        "# Keras requires the batch size be specified ahead of time for stateful models.\n",
        "# We use a sequence length of 1, as we will be feeding in one character at a \n",
        "# time and predicting the next character.\n",
        "prediction_model = lstm_model(seq_len=1, batch_size=BATCH_SIZE, stateful=True)\n",
        "prediction_model.load_weights('/tmp/bard.h5')\n",
        "\n",
        "# We seed the model with our initial string, copied BATCH_SIZE times\n",
        "\n",
        "seed_txt = 'Looks it not like the king?  Verily, we must go! '\n",
        "seed = transform(seed_txt)\n",
        "seed = np.repeat(np.expand_dims(seed, 0), BATCH_SIZE, axis=0)\n",
        "\n",
        "# First, run the seed forward to prime the state of the model.\n",
        "prediction_model.reset_states()\n",
        "for i in range(len(seed_txt) - 1):\n",
        "  prediction_model.predict(seed[:, i:i + 1])\n",
        "\n",
        "# Now we can accumulate predictions!\n",
        "predictions = [seed[:, -1:]]\n",
        "for i in range(PREDICT_LEN):\n",
        "  last_word = predictions[-1]\n",
        "  next_probits = prediction_model.predict(last_word)[:, 0, :]\n",
        "  \n",
        "  # sample from our output distribution\n",
        "  next_idx = [\n",
        "      np.random.choice(256, p=next_probits[i])\n",
        "      for i in range(BATCH_SIZE)\n",
        "  ]\n",
        "  predictions.append(np.asarray(next_idx, dtype=np.int32))\n",
        "  \n",
        "\n",
        "for i in range(BATCH_SIZE):\n",
        "  print('PREDICTION %d\\n\\n' % i)\n",
        "  p = [predictions[j][i] for j in range(PREDICT_LEN)]\n",
        "  generated = ''.join([chr(c) for c in p])\n",
        "  print(generated)\n",
        "  print()\n",
        "  assert len(generated) == PREDICT_LEN, 'Generated text too short'"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "PREDICTION 0\n",
            "\n",
            "\n",
            " vouchsafe to.\r\n",
            "\r\n",
            "[Enter Iailors dastard.]\r\n",
            "And yet these doublest masters young Snowing in a tongue bend interchanged factions of more do,\r\n",
            "Sett, appartives.\r\n",
            "\r\n",
            "THERSITES.\r\n",
            "Nor tis what a separate main, were they.\r\n",
            "\r\n",
            "POLONIUS.\r\n",
            "O, I could touch them\n",
            "\n",
            "PREDICTION 1\n",
            "\n",
            "\n",
            " Pay thee in them lo waste.\r\n",
            "Your head no hire and ground of all Truan:\r\n",
            "For in the glorious cou,\r\n",
            "Blunk, thence and laud, repentance, wrapping;\r\n",
            "For what you were there.\r\n",
            "\r\n",
            "ROMEO.\r\n",
            "A very lamentable if with a drunk.\r\n",
            "\r\n",
            "MONTAGUE.\r\n",
            "Yes, madam.\r\n",
            "\r\n",
            "PALA\n",
            "\n",
            "PREDICTION 2\n",
            "\n",
            "\n",
            " If sure you should say, although\r\n",
            "and means: the same, sir! It shall as the Lies; I am out to change.\r\n",
            "Unclaimer that hath plead'd drink,       220\r\n",
            "Wherein their sacred life to bed.\r\n",
            "Why wouldst thou take my lady? O, go indeed\r\n",
            "That piting at the r\n",
            "\n",
            "PREDICTION 3\n",
            "\n",
            "\n",
            " stands death]\r\n",
            "So lady, we are unsaying lambs, and oftendance\r\n",
            "Which, then; wheres Edic died all that?\r\n",
            "\r\n",
            "WOMAN.\r\n",
            "Then I will be seen.\r\n",
            "\r\n",
            "DIOMEDES.\r\n",
            "Tis the last,\r\n",
            "And so within a mosey of the mines,\r\n",
            "Look\r\n",
            "Farry this to then, voice he living propos\n",
            "\n",
            "PREDICTION 4\n",
            "\n",
            "\n",
            " I cannot way.\r\n",
            "You do not spent till to Thebs. Not as now\r\n",
            "We were hundred; and I feareth much well condemn'd to bear\r\n",
            "Though her mothers heads in this image slow:\r\n",
            "It was a soul that hath a beggar;\r\n",
            "For this absence doth sack again, thunder-boist's\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}